# Matthew Talluto
# 04.11.2021
#' Log-likelihood of a linear regression with normal errors
#'
#' The likelihood is evaluated with dnorm(). R's d****() functions (dnorm,
#' dexp, ...) return densities; log = TRUE returns log densities, which are
#' summed here instead of multiplying many small probabilities.
#'
#' params: named vector of parameters: intercept `a`, slope `b` and
#'   residual standard deviation `s`
#' data: list or data frame of all data, with elements `x` and `y`
#' returns: a single number, sum over observations of log N(y | a + b*x, s)
log_liklihood <- function(params, data) {
  # unpack params and data; [[ ]] returns plain (unnamed) values
  a <- params[["a"]]
  b <- params[["b"]]
  s <- params[["s"]]
  x <- data[["x"]]
  y <- data[["y"]]
  # compute mu, the expectation of y
  mu <- a + b * x
  # log-likelihood of y | x, a, b, s
  sum(dnorm(y, mean = mu, sd = s, log = TRUE))
}
#' Log prior density of the regression parameters
#'
#' Priors, with hyperparameters fixed inside the function:
#'   a ~ Normal(mean 0, sd 10)
#'   b ~ Normal(mean 0, sd 5)
#'   s ~ Exponential(rate 0.2)
#'
#' params: named vector of parameters (a, b, s)
#' returns: sum of the three log prior densities
log_prior <- function(params) {
  lp_a <- dnorm(params["a"], mean = 0, sd = 10, log = TRUE)
  lp_b <- dnorm(params["b"], mean = 0, sd = 5, log = TRUE)
  lp_s <- dexp(params["s"], rate = 0.2, log = TRUE)
  lp_a + lp_b + lp_s
}
#' Unnormalised log posterior: log-likelihood plus log prior
#'
#' params: named vector of parameters (a, b, s)
#' data: list or data frame of all data
#' returns: the log posterior, or -Inf when s is not a valid (positive,
#'   non-missing) standard deviation
log_posterior <- function(params, data) {
  s <- params[["s"]]
  # s must be positive; if we try an invalid value (or a missing one) we
  # have to return something sensible: the probability is 0, so the log
  # probability is -Inf
  if (is.na(s) || s <= 0) {
    return(-Inf)
  }
  log_liklihood(params, data) + log_prior(params)
}

data("iris")
# Drop setosa so only versicolor and virginica remain, then model petal
# length (y) as a function of sepal length (x)
iris <- iris[iris$Species != "setosa", ]
data <- with(iris, data.frame(x = Sepal.Length, y = Petal.Length))

# Maximum a posteriori (MAP) estimate; fnscale = -1 makes optim maximise
param_init <- c(a = 0, b = 0, s = 1)
mod_map <- optim(param_init, log_posterior, data = data,
  method = "Nelder-Mead", control = list(fnscale = -1))
# optim signals failure through $convergence (0 = success), not an error,
# so surface it rather than silently trusting the estimates
if (mod_map$convergence != 0) {
  warning("optim did not converge for the MAP fit (code ",
    mod_map$convergence, ")", call. = FALSE)
}

# Compare with the ordinary least-squares fit
mod_lm <- lm(y ~ x, data = data)
mod_map$par
## a b s
## -1.5516 1.0312 0.4598
coef(mod_lm)
## (Intercept) x
## -1.556 1.032
sd(mod_lm$residuals)
## [1] 0.4623
# Laplace approximation: after fitting log_posterior_la with optim(),
# fit$par gives us the mean/best-fit for each parameter, and vcv tells us
# the variance-covariance matrix of the multivariate normal distribution.

#' Log posterior on the (a, b, log_s) scale, for the Laplace approximation
#'
#' params: named vector of parameters a, b and log_s, where log_s is the
#'   natural log of the residual standard deviation s
#' data: list or data frame of all data
#' jacobian: if TRUE, add the change-of-variables term
#'   log|ds / dlog_s| = log_s, so the result is the log posterior density of
#'   log_s rather than the density of s evaluated at exp(log_s). Defaults to
#'   FALSE, which keeps the original behaviour.
#' returns: unnormalised log posterior
log_posterior_la <- function(params, data, jacobian = FALSE) {
  ## s must be positive, but optim doesn't know this. The Laplace
  ## approximation also works best when the parameter being optimised has
  ## a roughly normal posterior. Working with log(s) can solve both problems.
  log_s <- params[["log_s"]]
  params["s"] <- exp(log_s)
  lp <- log_liklihood(params, data) + log_prior(params)
  if (jacobian) {
    lp <- lp + log_s
  }
  lp
}
# Find the posterior mode on the (a, b, log_s) scale and the Hessian of the
# log posterior at that mode
param_init <- c(a = 0, b = 0, log_s = 0)
fit <- optim(param_init, log_posterior_la, data = data,
  method = "Nelder-Mead",
  control = list(fnscale = -1), hessian = TRUE)
# optim signals failure through $convergence (0 = success), not an error;
# a non-converged mode would make the approximation below meaningless
if (fit$convergence != 0) {
  warning("optim did not converge for the Laplace fit (code ",
    fit$convergence, ")", call. = FALSE)
}
# At a maximum the Hessian of the log posterior is negative definite; the
# covariance of the approximating normal is the inverse of its negative
vcv <- solve(-fit$hessian)
fit$par
## a b log_s
## -1.5508 1.0311 -0.7769
vcv
## a b log_s
## a 1.923e-01 -3.038e-02 4.855e-05
## b -3.038e-02 4.851e-03 -7.420e-06
## log_s 4.855e-05 -7.420e-06 4.994e-03
# Marginal standard deviations of the approximate posterior
sds <- sqrt(diag(vcv))
sds
## a b log_s
## 0.43856 0.06965 0.07067